{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "sys.path.append('.../BELLE/train/src')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from utils import MultiClient\n",
    "ip = '127.0.0.1'\n",
    "base_port = 17860\n",
    "worker_addrs = [\n",
    "    f\"http://{ip}:{base_port + i}\" for i in range(8)\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "GENERATION_CONFIG = [\n",
    "    # int | float (numeric value between 0 and 1) in 'Temperature' Slider component\n",
    "    0.9,\n",
    "    # int | float (numeric value between 0 and 1) in 'Top p' Slider component\n",
    "    0.6,\n",
    "    # int | float (numeric value between 0 and 100) in 'Top k' Slider component\n",
    "    30,\n",
    "    # int | float (numeric value between 1 and 4) in 'Beams Number' Slider component\n",
    "    1,\n",
    "    # do sample\n",
    "    True,\n",
    "    # int | float (numeric value between 1 and 2000) in 'Max New Tokens' Slider component\n",
    "    128,\n",
    "    # int | float (numeric value between 1 and 300) in 'Min New Tokens' Slider component\n",
    "    1,\n",
    "    # int | float (numeric value between 1.0 and 2.0) in 'Repetition Penalty' Slider component\n",
    "    1.2,\n",
    "]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# ZeRO Inference"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "data = pd.read_json('.../BELLE/data/test_data/test_infer.jsonl', lines=True)\n",
    "data = data['text'].tolist()\n",
    "data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "client = MultiClient(worker_addrs, synced_worker=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "template = \\\n",
    "\"\"\"Human: \n",
    "{text}\n",
    "\n",
    "Assistant: \n",
    "\"\"\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "GENERATION_CONFIG"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "tasks = []\n",
    "for i in range(len(data)):\n",
    "    tasks.append([template.format(text=data[i])] + GENERATION_CONFIG)\n",
    "answers = client.predict(tasks)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "answers"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 普通多进程并行推理"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "data = pd.read_json('.../BELLE/data/test_data/test_infer.jsonl', lines=True)\n",
    "data = data['text'].tolist()\n",
    "data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "client = MultiClient(worker_addrs, synced_worker=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "template = \\\n",
    "\"\"\"Human: \n",
    "{text}\n",
    "\n",
    "Assistant: \n",
    "\"\"\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "GENERATION_CONFIG"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "tasks = []\n",
    "for i in range(len(data)):\n",
    "    tasks.append([template.format(text=data[i])] + GENERATION_CONFIG)\n",
    "answers = client.predict(tasks)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "answers"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.11"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
